# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
from torch.nn.parameter import Parameter
import logging
import os
import argparse
import numpy as np
import math
from copy import deepcopy
from torch.autograd import grad
from scipy.optimize import linear_sum_assignment
from scipy.fftpack import dct, idct
from torch.utils.data import Dataset
from PIL import Image
import time
import thop
from thop import profile
import matplotlib.pyplot as plt
import lpips
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Prefer the first CUDA device, but fall back to the CPU so the script can still
# run (slowly) on a machine without a GPU instead of failing at the first
# `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Default for --leak_mode, formatted as "batch-<leak batch size>-gauss-<noise rate>".
leak_mode_value = "batch-1-gauss-0.0"

def set_logger(name='exp', filepath=None, level='INFO'):
    """Return a logger that writes to the console and, optionally, to a file.

    ``logging.getLogger`` hands back the same object for the same ``name``, so
    handlers are only attached when an equivalent one is not already present.
    Calling this function repeatedly therefore never duplicates log lines.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        filepath: Optional path of a log file opened in append mode. Missing
            parent directories are created.
        level: Level name (e.g. ``'INFO'``) applied to the logger.

    Returns:
        The configured ``logging.Logger``.
    """
    formatter = logging.Formatter(
        fmt="[%(asctime)s %(name)s] %(message)s",
        datefmt="%y-%m-%d %H:%M:%S")
    logger = logging.getLogger(name)
    logger.setLevel(logging.getLevelName(level))

    # FileHandler subclasses StreamHandler, hence the exact type comparison.
    has_console = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_console:
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    if filepath is not None:
        target = os.path.abspath(filepath)
        has_file = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == target
            for h in logger.handlers
        )
        if not has_file:
            directory = os.path.dirname(filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)
            file_handle = logging.FileHandler(filename=filepath, mode="a")
            file_handle.set_name("file")
            file_handle.setFormatter(formatter)
            logger.addHandler(file_handle)
    return logger

# Command-line interface. Each entry is (flag, keyword arguments for add_argument).
parser = argparse.ArgumentParser(description='Gradient Inversion Transcript.')
_CLI_OPTIONS = (
    ('--dataset', dict(type=str, default="ImageNet",
                       help='dataset to do the experiment')),
    ('--model', dict(type=str, default="GIT-3000",
                     help='GIT')),
    ('--shared_model', dict(type=str, default="ResNet",
                            help='leaked_model')),
    ('--lr', dict(type=float, default=1e-4,
                  help='learning rate')),
    ('--epochs', dict(type=int, default=50,
                      help='epochs for training')),
    ('--batch_size', dict(type=int, default=256,
                          help='batch_size for training the MLP')),
    ('--leak_mode', dict(type=str, default=leak_mode_value,
                         help='batch-{batch_size of leak model}/gauss-{noise rate}')),
    ('--trainset', dict(type=str, default="full")),
    ('--base_dir', dict(type=str, default="data",
                        help='base directory to save results')),
    # String defaults are run through `type` by argparse, so these end up as ints.
    ('--num_layers', dict(type=int, default="5",
                          help='number_of_layers')),
    ('--percent', dict(type=int, default="1",
                       help='/percent of training set')),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# The log file name encodes the full experiment configuration.
_log_stem = "_".join(
    str(part) for part in (
        args.dataset, args.shared_model, args.num_layers, args.model,
        args.leak_mode, args.lr, args.epochs, args.batch_size,
    )
)
logger = set_logger("", f"{args.base_dir}/{_log_stem}.txt")
logger.info(args)

print("Arguments parsed successfully.")


def weights_init(m):
    """Overwrite a module's ``weight`` and ``bias`` with samples from U(-0.5, 0.5).

    Intended for use with ``nn.Module.apply``. Attributes that are missing or
    set to ``None`` are left alone; ``weight`` is always drawn before ``bias``.
    """
    for attr_name in ("weight", "bias"):
        tensor = getattr(m, attr_name, None)
        if tensor is not None:
            tensor.data.uniform_(-0.5, 0.5)


class LeNet(nn.Module):
    """Stack of bias-free 2x2 "same" convolutions followed by one linear classifier.

    The classifier is created lazily on the first forward pass, because its
    input width depends on the spatial size of the input images.

    Args:
        num_classes: Number of output logits.
        num_layers: Number of Conv2d + ReLU pairs in the body.
    """

    def __init__(self, num_classes=10, num_layers=3):
        super(LeNet, self).__init__()
        # Stored so forward() does not depend on a module-level `num_classes`.
        self.num_classes = num_classes
        act = nn.ReLU
        layers = []
        in_channels = 3
        for _ in range(num_layers):
            layers.append(
                nn.Conv2d(in_channels, 12, kernel_size=2, bias=False, padding="same")
            )
            layers.append(act())
            in_channels = 12
        self.body = nn.Sequential(*layers)
        self.flatten = nn.Flatten()
        self.fc = None  # built on the first forward pass

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = self.flatten(self.body(x))
        if self.fc is None:
            # Size the classifier from the flattened feature width and place it
            # on the same device as the input.
            self.fc = nn.Linear(out.size(1), self.num_classes).to(x.device)
        return self.fc(out)


class LeNetMnist(nn.Module):
    """Three 5x5 sigmoid convolutions (strides 2, 2, 1) plus a linear classifier.

    The classifier expects 588 flattened features, i.e. 12 channels of 7x7,
    which is what a 28x28 input produces.
    """

    def __init__(self, input_channels=1, num_classes=10):
        super().__init__()
        kernel = 5
        stages = []
        channels_in = input_channels
        for stride in (2, 2, 1):
            stages.append(
                nn.Conv2d(channels_in, 12, kernel_size=kernel,
                          padding=kernel // 2, stride=stride)
            )
            stages.append(nn.Sigmoid())
            channels_in = 12
        self.body = nn.Sequential(*stages)
        self.fc = nn.Sequential(nn.Linear(588, num_classes))

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        features = self.body(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


class BasicBlock(nn.Module):
    """Residual block whose two 3x3 convolutions BOTH use ``stride``.

    Because the main path is strided twice, it shrinks the spatial size by
    ``stride ** 2``. The 1x1 shortcut convolution therefore uses
    ``stride * stride`` so both branches always agree in shape. For
    ``stride == 2`` this equals the previous ``stride * 2``; for ``stride == 1``
    with a channel change it fixes a shape mismatch on the residual add.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        stride: Stride applied by each of the two 3x3 convolutions.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # First convolutional layer
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second convolutional layer (also strided, see class docstring)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Identity shortcut unless the shape changes; then a bias-free 1x1
        # projection without batch norm.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride * stride,
                    bias=False
                )
            )

    def forward(self, x):
        """Apply the block; batch norm is skipped once the feature map width is 1."""
        identity = self.shortcut(x)
        out = self.conv1(x)
        out = self.bn1(out) if out.shape[-1] > 1 else out  # BN only when W > 1
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out) if out.shape[-1] > 1 else out
        out += identity  # residual connection
        out = self.relu(out)
        return out


class ResNet15(nn.Module):
    """Small ResNet-style classifier built from ``BasicBlock`` stages.

    A 7x7 stem convolution and max-pool are followed by four residual stages,
    global average pooling and a linear classifier. The classifier is created
    on the first forward pass from the pooled feature width.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        # Stem
        self.conv1 = nn.Conv2d(3, 12, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(12)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Residual stages: (attribute, in_channels, out_channels, blocks, stride).
        # Registered in this order so module/parameter enumeration is stable.
        stage_plan = (
            ("stage1", 12, 12, 1, 1),
            ("stage2", 12, 12, 2, 2),
            ("stage3", 12, 24, 2, 2),
            ("stage4", 24, 24, 2, 2),
        )
        for attr, c_in, c_out, blocks, stride in stage_plan:
            setattr(self, attr, self._make_stage(c_in, c_out, num_blocks=blocks, stride=stride))
        # Head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc = None  # created during the first forward pass

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        """Build ``num_blocks`` BasicBlocks; every block receives the same ``stride``."""
        blocks = [BasicBlock(in_channels, out_channels, stride)]
        blocks += [
            BasicBlock(out_channels, out_channels, stride=stride)
            for _ in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.stage1, self.stage2, self.stage3, self.stage4,
            self.avgpool, self.flatten,
        )
        for layer in pipeline:
            x = layer(x)
        if self.fc is None:
            # Lazily size the classifier from the pooled feature width.
            self.fc = nn.Linear(x.size(1), self.num_classes).to(x.device)
        return self.fc(x)


def label_to_onehot(target, num_classes=100):
    """Convert a 1-D tensor of class indices into a float one-hot matrix.

    Returns a ``(len(target), num_classes)`` tensor on ``target``'s device with
    a single 1 per row at the column given by the label.
    """
    column_index = target.unsqueeze(1)
    encoded = torch.zeros(column_index.size(0), num_classes, device=column_index.device)
    encoded.scatter_(1, column_index, 1)
    return encoded


def cross_entropy_for_onehot(pred, target):
    """Mean cross-entropy between logits ``pred`` and one-hot (or soft) ``target``."""
    log_probs = F.log_softmax(pred, dim=-1)
    per_sample = torch.sum(-target * log_probs, 1)
    return per_sample.mean()


def _prepare_leak_batch(batch, leak_batch):
    """Turn one DataLoader batch into decoder inputs and targets.

    ``batch`` is ``[g_layer_0, ..., g_layer_k, bias_gradient, ys]``. Samples are
    truncated to a multiple of ``leak_batch``, optionally perturbed with
    Gaussian noise of standard deviation ``gauss_noise`` (applied per sample,
    before averaging), averaged within every group of ``leak_batch`` samples
    and moved to ``device``.

    The previous in-loop prune/hash code rebound a local variable and never
    updated the gradient list, so it had no effect on the inputs. It also meant
    noise was never added to the weight gradients. That dead work is removed
    and the noise is now applied directly, the same way the warm-up pass of
    this script does. As a side effect the global RNG is no longer advanced by
    the unused hash draws, so runs are not bit-identical to the old ones.

    Returns:
        ``(g_layers, bias_gradient, ys, group_count, sample_count)``, or
        ``None`` when the batch holds fewer than ``leak_batch`` samples.
    """
    *g_layers, bias_gradient, ys = batch
    group_count = len(ys) // leak_batch
    if group_count == 0:
        return None
    sample_count = group_count * leak_batch

    g_layers = [g[:sample_count] for g in g_layers]
    bias_gradient = bias_gradient[:sample_count]
    ys = ys[:sample_count]

    if gauss_noise > 0:
        # Out-of-place adds so the tensors owned by the loader are not mutated.
        g_layers = [g + torch.randn_like(g) * gauss_noise for g in g_layers]
        bias_gradient = bias_gradient + torch.randn_like(bias_gradient) * gauss_noise

    # mean(1): average the gradients of the samples that share one leaked batch.
    g_layers = [
        g.view(group_count, leak_batch, *g.shape[1:]).mean(1).to(device)
        for g in g_layers
    ]
    bias_gradient = bias_gradient.view(group_count, leak_batch, -1).mean(1).to(device)
    ys = ys.view(group_count, leak_batch, -1).to(device)
    return g_layers, bias_gradient, ys, group_count, sample_count


def train(grad_to_img_net, data_loader, leak_batch=1):
    """Run one optimisation epoch and return the mean per-sample reconstruction MSE.

    Within each leaked batch the predictions are matched to the ground-truth
    images with the Hungarian algorithm, so the loss does not depend on the
    order in which the decoder emits the images. Uses the module-level
    ``optimizer`` and ``image_size``.

    Args:
        grad_to_img_net: Decoder mapping gradients to flattened images.
        data_loader: Loader yielding ``[g_layers..., bias_gradient, targets]``.
        leak_batch: Number of samples whose gradients are averaged together.

    Returns:
        Loss averaged over the samples actually used (batches are truncated to
        a multiple of ``leak_batch``); ``0.0`` if no sample was used.
    """
    grad_to_img_net.train()
    total_loss = 0
    total_num = 0
    for batch in data_loader:
        prepared = _prepare_leak_batch(batch, leak_batch)
        if prepared is None:
            # A trailing batch smaller than leak_batch cannot form a group.
            continue
        g_layers, bias_gradient, ys, group_count, sample_count = prepared
        total_num += sample_count

        optimizer.zero_grad()
        preds = grad_to_img_net(*g_layers, bias_grad=bias_gradient).view(group_count, leak_batch, -1)
        # Pairwise squared distances between targets (rows) and predictions (columns).
        batch_wise_mse = (torch.cdist(ys, preds) ** 2) / image_size
        loss = 0
        for mse_mat in batch_wise_mse:
            row_ind, col_ind = linear_sum_assignment(mse_mat.detach().cpu().numpy())
            loss += mse_mat[row_ind, col_ind].mean()
        loss = loss / group_count
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * sample_count
    return total_loss / total_num if total_num else 0.0


def test(grad_to_img_net, data_loader, leak_batch=1):
    """Evaluate the decoder and collect its reconstructions.

    Args:
        grad_to_img_net: Decoder mapping gradients to flattened images.
        data_loader: Loader yielding ``[g_layers..., bias_gradient, targets]``.
        leak_batch: Number of samples whose gradients are averaged together.

    Returns:
        ``(mean per-sample MSE, reconstructions)``. The reconstructions are on
        the CPU, ordered to line up with the targets inside every leaked batch,
        and reshaped to images for the datasets this script knows about.
    """
    grad_to_img_net.eval()
    total_loss = 0
    total_num = 0
    reconstructed_data = []
    with torch.no_grad():
        for batch in data_loader:
            prepared = _prepare_leak_batch(batch, leak_batch)
            if prepared is None:
                continue
            g_layers, bias_gradient, ys, group_count, sample_count = prepared
            total_num += sample_count

            preds = grad_to_img_net(*g_layers, bias_grad=bias_gradient).view(group_count, leak_batch, -1)
            batch_wise_mse = (torch.cdist(ys, preds) ** 2) / image_size
            loss = 0
            for batch_id, mse_mat in enumerate(batch_wise_mse):
                row_ind, col_ind = linear_sum_assignment(mse_mat.detach().cpu().numpy())
                loss += mse_mat[row_ind, col_ind].sum()

                # Reorder predictions so that row r holds the match of target r.
                sorted_preds = preds[batch_id, col_ind]
                sorted_preds[row_ind] = preds[batch_id, col_ind]
                reconstructed_data.append(sorted_preds.view(leak_batch, -1).detach().cpu())
            total_loss += loss.item()
    reconstructed_data = torch.cat(reconstructed_data)

    if args.dataset in ["FashionMNIST", "MNIST"]:
        image_shape = (1, 28, 28)
    elif args.dataset.startswith("CIFAR10"):
        image_shape = (3, 32, 32)
    elif args.dataset.startswith("TinyImageNet"):
        image_shape = (3, 64, 64)
    elif args.dataset.startswith("ImageNet"):
        image_shape = (3, 224, 224)
    elif args.dataset.startswith("FER2013"):
        image_shape = (3, 48, 48)
    else:
        image_shape = None  # unknown dataset: keep the flat layout
    if image_shape is not None:
        reconstructed_data = reconstructed_data.view(-1, *image_shape)
    total_loss = total_loss / total_num

    return total_loss, reconstructed_data


def calculate_model_size(net):
    """Print and return the shape and element count of every parameter of ``net``.

    Returns:
        Dict mapping the parameter's position in ``net.parameters()`` to
        ``{"size": element count, "shape": torch.Size}``.
    """
    model_info = {}
    for index, tensor in enumerate(net.parameters()):
        shape = tensor.size()
        element_count = np.prod(shape)
        print(f"Layer {index} - Weight shape: {shape}, Total size: {element_count}")
        model_info[index] = {"size": element_count, "shape": shape}
    return model_info


# Geometry (channels, height, width) and class count of the selected dataset.
# An unknown dataset now fails here with a clear message instead of a NameError
# further down.
if args.dataset in ["FashionMNIST", "MNIST"]:
    image_shape, num_classes = (1, 28, 28), 10
elif args.dataset.startswith("CIFAR10"):
    image_shape, num_classes = (3, 32, 32), 10
elif args.dataset == "TinyImageNet":
    image_shape, num_classes = (3, 64, 64), 200
elif args.dataset == "ImageNet":
    image_shape, num_classes = (3, 224, 224), 100
elif args.dataset == "FER2013":
    image_shape, num_classes = (3, 48, 48), 7
else:
    raise ValueError(f"Unsupported --dataset: {args.dataset!r}")
image_size = image_shape[0] * image_shape[1] * image_shape[2]


# Build the model shared among parties (the one whose gradients leak).
if args.shared_model == "LeNetMnist":
    net = LeNetMnist(input_channels=1, num_classes=num_classes)
elif args.shared_model == "LeNet":
    net = LeNet(num_classes=num_classes, num_layers=args.num_layers)
elif args.shared_model == "ResNet":
    net = ResNet15(num_classes=num_classes)
else:
    raise ValueError(f"Unsupported --shared_model: {args.shared_model!r}")
net = net.to(device)
compress_rate = 1.0
# Fixed seed so the shared model's weights are reproducible across runs.
torch.manual_seed(1234)
net.apply(weights_init)
criterion = cross_entropy_for_onehot


# One dummy forward pass materialises the lazily created classifier layer, so
# the parameter list is complete before gradients are collected. Building the
# input from `image_shape` also covers MNIST/FashionMNIST, which previously had
# no dummy input and crashed here.
dummy_input = torch.randn(1, *image_shape).to(device)
net(dummy_input)
model_info = calculate_model_size(net)

# generate training / test dataset
# The gradient dataset is cached on disk; the cache file name encodes the
# experiment configuration and, when not "full", the --trainset tag.
if args.trainset == "full":
    checkpoint_name = f"{args.base_dir}/{args.dataset}_{args.shared_model}_{args.num_layers}_{args.model}_grad_to_img.pl"
else:
    checkpoint_name = f"{args.base_dir}/{args.dataset}_{args.shared_model}_{args.num_layers}_{args.model}_{args.trainset}_grad_to_img.pl"
if not os.path.exists(checkpoint_name):
    # No cache yet: load the raw images and compute the gradients from scratch.
    print("generating dataset...")
    if args.dataset == "MNIST":
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dst_train = datasets.MNIST("~/.torch", download=True, train=True, transform=transform)
        dst_test = datasets.MNIST("~/.torch", download=True, train=False, transform=transform)
    elif args.dataset == "FashionMNIST":
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dst_train = datasets.FashionMNIST("~/.torch", download=True, train=True, transform=transform)
        dst_test = datasets.FashionMNIST("~/.torch", download=True, train=False, transform=transform)
    elif args.dataset == "CIFAR10":
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        dst_train = datasets.CIFAR10("~/.torch", download=True, train=True, transform=transform)
        dst_test = datasets.CIFAR10("~/.torch", download=True, train=False, transform=transform)
    elif args.dataset == "TinyImageNet":
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
        train_dir = "/data/tiny-imagenet-200/train"
        val_dir = "/data/tiny-imagenet-200/test"
        dst_train = datasets.ImageFolder(train_dir, transform=transform)
        dst_test = datasets.ImageFolder(val_dir, transform=transform)
    elif args.dataset == "ImageNet":
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        train_dir = "/data/imagenet100/train"
        val_dir = "/data/imagenet100/val"
        dst_train = datasets.ImageFolder(train_dir, transform=transform)
        dst_test = datasets.ImageFolder(val_dir, transform=transform)
    elif args.dataset == "FER2013":
        # Images are resized to 48x48 and the grayscale data is expanded to
        # three channels. (An earlier CIFAR-sized 32x32 variant is kept below,
        # commented out.)
        transform = transforms.Compose([
            # transforms.Resize((32, 32)),  # shrink to CIFAR size
            transforms.Resize((48, 48)),
            transforms.Grayscale(num_output_channels=3),  # optional: grayscale to pseudo-RGB
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # optional normalization
        ])
        train_dir = "/data/FER2013/train"
        val_dir = "/data/FER2013/test"
        dst_train = datasets.ImageFolder(train_dir, transform=transform)
        dst_test = datasets.ImageFolder(val_dir, transform=transform)


    # Keep only the first 1/percent of the training set; the test set is used in full.
    num_train_samples = len(dst_train)
    num_test_samples = len(dst_test)
    dst_train = torch.utils.data.Subset(dst_train, range(num_train_samples // args.percent))
    dst_test = torch.utils.data.Subset(dst_test, range(num_test_samples))
    # batch_size=1 so that every gradient computed below belongs to one image.
    train_loader = torch.utils.data.DataLoader(dataset=dst_train,
                                               batch_size=1,
                                               shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=dst_test,
                                              batch_size=1,
                                              shuffle=False)


    def leakage_dataset(data_loader, net, criterion, image_size, num_classes):
        """
        Compute, for every batch of `data_loader`, the loss gradients w.r.t. the
        weights of all Conv2d and Linear modules of `net` plus the bias of the
        last such module, and record the flattened input images alongside.

        Every Conv2d found by `net.named_modules()` is included, which covers
        the 1x1 shortcut convolutions as well; batch-norm parameters are not.

        Returns a tuple `(*weight_gradients, bias_gradients, targets)` of CPU
        tensors whose first dimension is the number of samples in the dataset.
        All buffers are preallocated up front, so memory grows with
        dataset size times model size.
        """
        total_samples = len(data_loader.dataset)
        g_features = []
        # Extract only conv layers and fc layer
        conv_fc_layers = []
        for name, module in net.named_modules():
            if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
                conv_fc_layers.append((name, module))

        # Initialize g_features for weight gradients
        for _, layer in conv_fc_layers:
            if layer.weight is not None:
                g_features.append(torch.zeros([total_samples, *layer.weight.shape]))

        # Initialize bias_gradients for the last fc layer
        fc_layer = conv_fc_layers[-1][1]  # Last layer should be fc
        bias_gradients = torch.zeros([total_samples, fc_layer.out_features])

        targets = torch.zeros([total_samples, image_size])
        sample_index = 0

        for images, labels in data_loader:
            onehot_labels = label_to_onehot(labels, num_classes=num_classes)
            images, onehot_labels = images.to(device), onehot_labels.to(device)

            # Forward pass
            pred = net(images)
            loss = criterion(pred, onehot_labels)

            # Compute gradients: one entry per weight tensor, then the fc bias last
            dy_dx = torch.autograd.grad(loss,
                                        [layer.weight for _, layer in conv_fc_layers if layer.weight is not None] +
                                        [fc_layer.bias])

            batch_size_current = images.size(0)

            # Store gradients of conv + fc layers. The gradient of the batch loss
            # is written to every row of the slice, so it is a per-sample
            # gradient only when the loader uses batch_size=1.
            for j in range(len(g_features)):
                g_features[j][sample_index:sample_index + batch_size_current] = dy_dx[j].detach().cpu()

            # Store bias gradient for fc layer
            bias_gradients[sample_index:sample_index + batch_size_current] = dy_dx[-1].detach().cpu()

            # Store image data
            targets[sample_index:sample_index + batch_size_current] = images.view(batch_size_current, -1).cpu()

            sample_index += batch_size_current

        return (*g_features, bias_gradients, targets)


    # Collect train and test gradients into one dict and save it as the cache.
    checkpoint = {}
    *g_train_layers, train_bias, train_targets = leakage_dataset(train_loader, net, criterion, image_size,
                                                                 num_classes)
    # checkpoint: training split, keys "g_layer_<idx>"
    for idx, g_layer in enumerate(g_train_layers):
        checkpoint[f"g_layer_{idx}"] = g_layer
    checkpoint["train_targets"] = train_targets
    checkpoint["train_bias"] = train_bias
    *g_test_layers, test_bias, test_targets = leakage_dataset(test_loader, net, criterion, image_size,
                                                              num_classes)
    # checkpoint: test split, keys "g_test_layer_<idx>"
    for idx, g_layer in enumerate(g_test_layers):
        checkpoint[f"g_test_layer_{idx}"] = g_layer
    checkpoint["test_targets"] = test_targets
    checkpoint["test_bias"] = test_bias
    torch.save(checkpoint, checkpoint_name)

    # Iterate through each layer in model_info and print the details
    for layer_idx, layer_info in model_info.items():
        print(f"Layer {layer_idx}:")
        print(f"  Size: {layer_info['size']}")
        print(f"  Shape: {layer_info['shape']}\n")
    # Print the shape of each layer in g_train_layers
    for idx, g_layer in enumerate(g_train_layers):
        print(f"Layer {idx} gradient shape: {g_layer.shape}")
    print(f"bias shape: {train_bias.shape}")
else:
    # Cache hit: reuse the previously generated gradient dataset.
    checkpoint = torch.load(checkpoint_name)
# The shared model is only needed to generate gradients; release it.
del net


print("loading dataset...")


def _tensors_for_split(layer_prefix, bias_key, targets_key):
    """Collect one split's tensors from ``checkpoint`` in TensorDataset order.

    Gradient tensors are read from consecutive keys ``<layer_prefix>0``,
    ``<layer_prefix>1``, ... until a key is missing, followed by the bias
    gradients and finally the target images.
    """
    tensors = []
    for idx in range(len(checkpoint)):
        key = f"{layer_prefix}{idx}"
        if key not in checkpoint:
            break  # no further layers stored for this split
        tensors.append(checkpoint[key])
    tensors.extend([checkpoint[bias_key], checkpoint[targets_key]])
    return tensors


# Each dataset item is (g_layer_0, ..., g_layer_k, bias_gradient, target_image).
trainset = torch.utils.data.TensorDataset(
    *_tensors_for_split("g_layer_", "train_bias", "train_targets"))
testset = torch.utils.data.TensorDataset(
    *_tensors_for_split("g_test_layer_", "test_bias", "test_targets"))


# leakage mode: "batch-<n>" sets how many samples share one leaked gradient,
# "gauss-<sigma>" sets the Gaussian noise level. Tokens are separated by "-",
# so values containing "-" (negative numbers, "1e-3") cannot be expressed.
leak_batch = 1
gauss_noise = 0
leak_mode_list = args.leak_mode.split("-")
for position, token in enumerate(leak_mode_list):
    if token == "batch":
        leak_batch = int(leak_mode_list[position + 1])
    elif token == "gauss":
        gauss_noise = float(leak_mode_list[position + 1])
print(leak_batch, gauss_noise)


# init the model
torch.manual_seed(0)
# selected_para = torch.randperm(model_size)[:int(model_size * compress_rate)]


# The hidden width is the number after the last "-" in e.g. "GIT-3000".
if args.model.startswith("GIT"):
    hidden_size = int(args.model.rsplit("-", 1)[-1])


class GIT(nn.Module):
    """MLP decoder that consumes layer gradients from the last layer backwards.

    The first linear layer sees the last layer's weight gradient together with
    the bias gradient. Every following layer sees the previous activation
    concatenated with the gradients of two adjacent layers. All linear layers
    are created lazily on the first forward pass; the final one outputs
    ``image_size * leak_batch`` values (both read from module level).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.fc_layers = nn.ModuleList()
        self.weights = []

    def forward(self, *g_list, bias_grad):
        """Decode gradients ``g_list`` (first to last layer) and ``bias_grad``."""
        n_samples = g_list[0].shape[0]
        depth = len(g_list)
        target_device = g_list[0].device
        last_grad = g_list[-1]  # [bs, dN, dN-1]
        first_width = last_grad.shape[1] * last_grad.shape[2] + bias_grad.shape[1]
        if len(self.fc_layers) == 0:
            self.fc_layers.append(nn.Linear(first_width, self.hidden_size).to(target_device))

        x = torch.cat([last_grad.flatten(1), bias_grad.flatten(1)], dim=1).to(target_device)
        x = F.relu(self.fc_layers[0](x))

        # Walk from the second-to-last gradient down to the first one.
        for i in reversed(range(depth - 1)):
            slot = depth - 1 - i
            upper = g_list[i + 1].view(n_samples, -1)
            lower = g_list[i].view(n_samples, -1)
            x = torch.cat([x, upper, lower], dim=1).to(target_device)
            if len(self.fc_layers) <= slot:
                # The layer created last maps to the flattened image(s).
                is_output = len(self.fc_layers) == depth - 1
                width = image_size * leak_batch if is_output else self.hidden_size
                self.fc_layers.append(nn.Linear(x.shape[1], width).to(target_device))
            x = F.relu(self.fc_layers[slot](x))
        return x


class GIT_shortcut(nn.Module):
    """GIT decoder variant that mirrors residual connections of the shared model.

    Gradients of 1x1 convolutions (trailing shape ``(1, 1)``) are treated as
    shortcut markers: they are removed from the decoder inputs, and at marked
    positions the activation from two steps earlier is added to the current
    one. Linear layers are created lazily on the first forward pass, so a
    forward pass must run before the parameters are handed to an optimizer.
    The final layer outputs ``image_size * leak_batch`` values (module level).
    """

    def __init__(self, hidden_size):
        super(GIT_shortcut, self).__init__()
        self.hidden_size = hidden_size
        self.fc_layers = nn.ModuleList()
        self.shortcut_layers = nn.ModuleList()
        # Decoder position -> index of its projection inside `shortcut_layers`.
        # Projections exist only where the widths differ, so the list is sparse
        # and cannot be indexed by position directly.
        self._shortcut_slots = {}
        self.weights = []

    def forward(self, *g_list, bias_grad):
        """Decode gradients ``g_list`` (first to last layer) and ``bias_grad``."""
        batch_size = g_list[0].shape[0]
        device = g_list[0].device
        # NOTE(review): these indices refer to the unfiltered list, while the
        # loop below counts positions in the filtered list - confirm intended.
        shortcut_indices = {i for i, g in enumerate(g_list) if g.shape[3:] == (1, 1)}
        g_list = [g for i, g in enumerate(g_list) if i not in shortcut_indices]
        N = len(g_list)
        gN = g_list[-1]  # [batch_size, dN, dN-1]
        input_size = gN.shape[1] * gN.shape[2] + bias_grad.shape[1]
        if len(self.fc_layers) == 0:
            self.fc_layers.append(nn.Linear(input_size, self.hidden_size).to(device))

        x = torch.cat([gN.flatten(1), bias_grad.flatten(1)], dim=1)
        # ablation study for bias
        # x = gN.flatten(1)
        x = x.to(device)
        x = F.relu(self.fc_layers[0](x))
        prev_x = [x]
        for i in range(N - 2, -1, -1):
            position = N - 1 - i
            g_next = g_list[i + 1].view(batch_size, -1)
            g_current = g_list[i].view(batch_size, -1)
            x = torch.cat([x, g_next, g_current], dim=1).to(device)

            if len(self.fc_layers) <= position:
                # The layer created last maps to the flattened image(s).
                if len(self.fc_layers) == N - 1:
                    output_size = image_size * leak_batch
                else:
                    output_size = self.hidden_size
                self.fc_layers.append(nn.Linear(x.shape[1], output_size).to(device))
            x = self.fc_layers[position](x)

            # A residual needs an activation from two steps back; skip it while
            # fewer than two activations exist instead of raising IndexError.
            if (i + 2) in shortcut_indices and len(prev_x) >= 2:
                shortcut_x = prev_x[-2]
                if shortcut_x.shape[1] != x.shape[1]:
                    # Create the projection for this position once, reuse it after.
                    slot = self._shortcut_slots.get(position)
                    if slot is None:
                        slot = len(self.shortcut_layers)
                        self.shortcut_layers.append(
                            nn.Linear(shortcut_x.shape[1], x.shape[1]).to(device))
                        self._shortcut_slots[position] = slot
                    shortcut_x = self.shortcut_layers[slot](shortcut_x)
                x = x + shortcut_x

            x = F.relu(x)
            prev_x.append(x)

        return x


# Define the GIT: the residual-aware decoder for the ResNet shared model,
# the plain one otherwise.
_decoder_cls = GIT_shortcut if args.shared_model == "ResNet" else GIT
grad_to_img_net = _decoder_cls(hidden_size=hidden_size).to(device)


# load the dataloader (input-gradient pairs)
# Each step draws batch_size groups of leak_batch samples.
batch_size = args.batch_size
_samples_per_step = batch_size * leak_batch
train_loader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=_samples_per_step, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=_samples_per_step, shuffle=False)


# Initialize GIT: a single forward pass on the first training batch creates the
# decoder's lazily built layers. `g_layers` and `bias_gradient` stay defined at
# module level afterwards.
for i, batch in enumerate(train_loader):
    *g_layers, bias_gradient, ys = batch
    batch_num = len(ys)
    batch_size = int(batch_num / leak_batch)
    batch_num = batch_size * leak_batch  # drop samples that do not fill a group
    g_layers = [g.to(device)[:batch_num] for g in g_layers]
    bias_gradient = bias_gradient.to(device)[:batch_num]
    ys = ys.to(device)[:batch_num]
    if gauss_noise > 0:
        for idx, layer_grad in enumerate(g_layers):
            g_layers[idx] = layer_grad + torch.randn(*layer_grad.shape).to(device) * gauss_noise

    # Average the gradients of the samples that share one leaked batch.
    g_layers = [g.view(batch_size, leak_batch, *g.shape[1:]).mean(1) for g in g_layers]
    bias_gradient = bias_gradient.view(batch_size, leak_batch, -1).mean(1)
    ys = ys.view(batch_size, leak_batch, -1)

    preds = grad_to_img_net(*g_layers, bias_grad=bias_gradient).view(batch_size, leak_batch, -1)
    break


# Calculate number of parameters of GIT (must run after GIT initialization,
# because the decoder's layers are created during its first forward pass).
num_params = 0
for _param in grad_to_img_net.parameters():
    num_params += _param.numel()
print(f"Total number of parameters: {num_params}")


# training set-up
lr, epochs = args.lr, args.epochs
optimizer = torch.optim.Adam(grad_to_img_net.parameters(), lr=lr)
# optimizer = torch.optim.Adam(grad_to_img_net.parameters(), lr=lr, weight_decay=1e-5)


# Reformat the ground-truth test images from flat vectors back to image tensors.
gt_data = checkpoint["test_targets"]
if args.dataset in ["FashionMNIST", "MNIST"]:
    _gt_shape = (1, 28, 28)
elif args.dataset.startswith("CIFAR10"):
    _gt_shape = (3, 32, 32)
elif args.dataset == "TinyImageNet":
    _gt_shape = (3, 64, 64)
elif args.dataset == "ImageNet":
    _gt_shape = (3, 224, 224)
elif args.dataset == "FER2013":
    _gt_shape = (3, 48, 48)
else:
    _gt_shape = None  # unknown dataset: leave the data flat
if _gt_shape is not None:
    gt_data = gt_data.view(-1, *_gt_shape)
# The tensors now live in the datasets built earlier; drop the extra reference.
del checkpoint


class WrappedModel(nn.Module):
    """Adapter that exposes ``model`` through a two-argument ``forward``.

    The wrapped network is invoked as ``model(*g_list, bias_grad=bias_grad)``;
    this wrapper lets a caller hand over the gradient list and the bias
    gradient as two plain positional inputs instead.
    """

    def __init__(self, model):
        """Store the network to delegate to as ``self.model``."""
        super().__init__()
        self.model = model

    def forward(self, g_list, bias_grad):
        """Unpack ``g_list`` positionally and pass ``bias_grad`` by keyword."""
        inner = self.model
        # Call the module itself (not ``inner.forward``) so its hooks still run.
        return inner(*g_list, bias_grad=bias_grad)
# Estimate the compute cost with thop by profiling a single forward pass on
# the batch that was prepared during GIT initialization.
wrapped_model = WrappedModel(grad_to_img_net)
flops_forward, _ = profile(wrapped_model, inputs=(g_layers, bias_gradient))
# NOTE(review): the factor 3 presumably counts the forward pass plus a
# backward pass at roughly twice its cost — confirm.
flops_per_batch = 3 * flops_forward
total_train_steps = len(train_loader) * args.epochs
total_flops = total_train_steps * flops_per_batch
# Wall-clock accumulators that the training loop fills in.
total_train_time, last_test_time = 0.0, 0.0

# Main loop: train for one epoch, evaluate, remember the best weights, and
# apply the step-wise learning-rate schedule.
# Start from +inf rather than a magic number so that the first evaluation is
# always recorded as the best so far, whatever the scale of the loss.
best_test_loss = float("inf")
best_state_dict = None
for epoch in range(args.epochs):
    # perf_counter() is monotonic, so the measured intervals are not affected
    # by system clock adjustments.
    train_start = time.perf_counter()
    train_loss = train(grad_to_img_net, train_loader, leak_batch=leak_batch)
    train_end = time.perf_counter()
    total_train_time += train_end - train_start
    test_start = time.perf_counter()
    test_loss, reconstructed_imgs = test(grad_to_img_net, test_loader, leak_batch=leak_batch)
    test_end = time.perf_counter()
    if epoch == args.epochs - 1:
        # Only the duration of the final evaluation is reported.
        last_test_time = test_end - test_start
    if test_loss < best_test_loss:
        best_test_loss = test_loss
        # Snapshot the weights on the CPU so the copy does not occupy device
        # memory. The model is moved only when a snapshot is actually taken,
        # instead of making a CPU round trip on every epoch.
        grad_to_img_net.cpu()
        best_state_dict = deepcopy(grad_to_img_net.state_dict())
        grad_to_img_net.to(device)
    logger.info(f"epoch: {epoch}, train_loss: {train_loss}, test_loss: {test_loss}, best_test_loss: {best_test_loss}")
    # Learning-rate schedule: multiply by 0.1 after 25% of the epochs and by
    # 0.5 after 75% of the epochs.
    if (epoch + 1) == int(0.25 * args.epochs):
        for g in optimizer.param_groups:
            g['lr'] *= 0.1
    if (epoch + 1) == int(0.75 * args.epochs):
        for g in optimizer.param_groups:
            g['lr'] *= 0.5
logger.info(f"Total training time: {total_train_time} seconds")
logger.info(f"Last test time: {last_test_time} seconds")

# Bundle cost statistics, final/best weights and the evaluation images into a
# single checkpoint dictionary.
checkpoint = {
    "total_flops": total_flops,  # estimated total cost of the training run
    "total_train_time": total_train_time,  # accumulated training time
    "last_test_time": last_test_time,  # duration of the last test pass
    "flops_per_sample": flops_forward,  # cost of the profiled forward pass
    "avg_epoch_time": total_train_time/args.epochs,  # mean training time per epoch
    "train_loss": train_loss,
    "test_loss": test_loss,
    "state_dict": grad_to_img_net.state_dict(),
    "best_test_loss": best_test_loss,
    "best_state_dict": best_state_dict,
    "reconstructed_imgs": reconstructed_imgs,
    "gt_data": gt_data,
}

# The training-set tag is part of the file name only for non-"full" runs.
if args.trainset == "full":
    ckpt_path = f"{args.base_dir}/{args.dataset}_{args.shared_model}_{args.num_layers}_{args.model}_{args.leak_mode}_{args.lr}_{args.epochs}_{args.batch_size}.pt"
else:
    ckpt_path = f"{args.base_dir}/{args.dataset}_{args.trainset}_{args.shared_model}_{args.num_layers}_{args.model}_{args.leak_mode}_{args.lr}_{args.epochs}_{args.batch_size}.pt"
torch.save(checkpoint, ckpt_path)


# Print computational cost, one "name: value" line per statistic.
for cost_name, cost_value in (
    ("total_flops", total_flops),
    ("total_train_time", total_train_time),
    ("last_test_time", last_test_time),
    ("flops_per_sample", flops_forward),
    ("avg_epoch_time", total_train_time/args.epochs),
):
    print(f"{cost_name}: {cost_value}")


# Load checkpoint (saved model, reconstructed images, etc).
# The file name must be built exactly as it was when the checkpoint was saved:
# runs with a non-"full" training set carry the training-set tag in the name.
# Always using the "full" name would fail (or pick up a stale file) for them.
if args.trainset == "full":
    checkpoint_name = f"{args.base_dir}/{args.dataset}_{args.shared_model}_{args.num_layers}_{args.model}_{args.leak_mode}_{args.lr}_{args.epochs}_{args.batch_size}.pt"
else:
    checkpoint_name = f"{args.base_dir}/{args.dataset}_{args.trainset}_{args.shared_model}_{args.num_layers}_{args.model}_{args.leak_mode}_{args.lr}_{args.epochs}_{args.batch_size}.pt"
checkpoint = torch.load(checkpoint_name)

# Per-dataset image geometry as (channels, height, width).
if args.dataset in ["FashionMNIST", "MNIST"]:
    image_shape = (1, 28, 28)
elif args.dataset.startswith("CIFAR10"):
    image_shape = (3, 32, 32)
elif args.dataset.startswith("TinyImageNet"):
    image_shape = (3, 64, 64)
elif args.dataset.startswith("ImageNet"):
    image_shape = (3, 224, 224)
elif args.dataset.startswith("FER2013"):
    image_shape = (3, 48, 48)
else:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(f"Unsupported dataset for evaluation: {args.dataset!r}")

# Bring both image sets into (N, C, H, W) layout on the CPU.
reconstructed_data = checkpoint["reconstructed_imgs"].view(-1, *image_shape).cpu()
gt_data = checkpoint["gt_data"].view(-1, *image_shape).cpu()


# Per-image quality metrics computed from the reconstructions stored in the
# checkpoint: MSE, PSNR, SSIM and LPIPS.
# SSIM metric module, configured with a data range of 1.0.
ssim_module = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
# LPIPS with the 'alex' backbone, switched to eval mode.
lpips_model = lpips.LPIPS(net='alex', verbose=False).to(device)
lpips_model.eval()

mse_list, psnr_list, ssim_list, lpips_list = [], [], [], []
max_pixel = 1.0
for i, gt_image in enumerate(gt_data):
    img_gt = gt_image.to(device)
    img_recon = reconstructed_data[i].to(device)

    # MSE over all pixels and channels of this image pair.
    mse = torch.mean((img_gt - img_recon) ** 2)
    mse_list.append(mse.item())

    # PSNR in dB; the small epsilon avoids a division by zero for a perfect
    # reconstruction.
    psnr = 10 * torch.log10(max_pixel ** 2 / (mse + 1e-10))
    psnr_list.append(psnr.item())

    # SSIM works on batches, so add a leading batch dimension.
    ssim_val = ssim_module(img_gt.unsqueeze(0), img_recon.unsqueeze(0))
    ssim_list.append(ssim_val.item())

    # LPIPS: single-channel images are tiled to three channels first, then
    # both inputs are rescaled with x * 2 - 1.
    single_channel = img_gt.shape[0] == 1
    img_gt_lpips = img_gt.repeat(3, 1, 1) if single_channel else img_gt
    img_recon_lpips = img_recon.repeat(3, 1, 1) if single_channel else img_recon
    img_gt_lpips = img_gt_lpips.unsqueeze(0) * 2 - 1
    img_recon_lpips = img_recon_lpips.unsqueeze(0) * 2 - 1
    with torch.no_grad():
        lpips_score = lpips_model(img_gt_lpips, img_recon_lpips)
    lpips_list.append(lpips_score.item())

# Average every per-image metric over the evaluated images and report it.
mean_mse, mean_psnr, mean_ssim, mean_lpips = (
    np.mean(values)
    for values in (mse_list, psnr_list, ssim_list, lpips_list)
)
print(f"Mean MSE: {mean_mse:.6f}")
print(f"Mean PSNR: {mean_psnr:.2f} dB")
print(f"Mean SSIM: {mean_ssim:.4f}")
print(f"Mean LPIPS: {mean_lpips:.4f}")


# Plot first 8: ground-truth images on the top row, their reconstructions on
# the bottom row.
fig, axes = plt.subplots(2, 8, figsize=(15, 6))
# Hide every axis up front so that columns without an image stay blank.
for ax in axes.ravel():
    ax.axis('off')
# Never index past the available images: with fewer than 8 samples a fixed
# range(8) would raise an IndexError.
num_shown = min(8, len(gt_data), len(reconstructed_data))
for i in range(num_shown):
    image = gt_data[i].numpy()
    reconstructed_image = reconstructed_data[i].numpy()
    # tensor layout (channels, height, width) -> (height, width, channels)
    image = image.transpose(1, 2, 0)
    reconstructed_image = reconstructed_image.transpose(1, 2, 0)
    axes[0, i].imshow(image, cmap='gray')
    axes[1, i].imshow(reconstructed_image, cmap='gray')
plt.savefig(
    f"{args.base_dir}/{args.dataset}_{args.leak_mode}_{args.epochs}_{args.shared_model}_{args.num_layers}_{args.model}.pdf",
    bbox_inches='tight')
plt.show()


# Plot best 8: the (up to) eight image pairs with the lowest LPIPS score.
sorted_indices = np.argsort(lpips_list)[:8]
# Two rows of eight panels: ground truth on top, reconstruction below.
fig, axes = plt.subplots(2, 8, figsize=(20, 5))
for i, idx in enumerate(sorted_indices):
    # (channels, height, width) tensors -> (height, width, channels) arrays.
    image = gt_data[idx].cpu().numpy().transpose(1, 2, 0)
    reconstructed_image = reconstructed_data[idx].cpu().numpy().transpose(1, 2, 0)
    # Column i: original first, reconstruction second, both without axes.
    for panel, picture in zip(axes[:, i], (image, reconstructed_image)):
        panel.imshow(picture, cmap='gray')
        panel.axis('off')
# Save the figure next to the other experiment outputs.
best8_path = f"{args.base_dir}/{args.dataset}_{args.leak_mode}_{args.epochs}_{args.shared_model}_{args.num_layers}_{args.model}_best8.pdf"
plt.savefig(best8_path, bbox_inches='tight')